Eager Few Shot Object Detection Colab

Welcome to the Eager Few Shot Object Detection Colab --- in this colab we demonstrate fine-tuning of a (TF2-friendly) RetinaNet architecture on very few examples of a novel class after initializing from a pre-trained COCO checkpoint. Training runs in eager mode.

Estimated time to run through this colab (with GPU): < 5 minutes.

Imports

In [1]:
!pip install -U --pre tensorflow=="2.2.0"
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
ERROR: Could not find a version that satisfies the requirement tensorflow==2.2.0 (from versions: 2.5.0, 2.5.1, 2.5.2, 2.5.3, 2.6.0rc0, 2.6.0rc1, 2.6.0rc2, 2.6.0, 2.6.1, 2.6.2, 2.6.3, 2.6.4, 2.6.5, 2.7.0rc0, 2.7.0rc1, 2.7.0, 2.7.1, 2.7.2, 2.7.3, 2.7.4, 2.8.0rc0, 2.8.0rc1, 2.8.0, 2.8.1, 2.8.2, 2.8.3, 2.8.4, 2.9.0rc0, 2.9.0rc1, 2.9.0rc2, 2.9.0, 2.9.1, 2.9.2, 2.9.3, 2.10.0rc0, 2.10.0rc1, 2.10.0rc2, 2.10.0rc3, 2.10.0, 2.10.1, 2.11.0rc0, 2.11.0rc1, 2.11.0rc2, 2.11.0, 2.11.1, 2.12.0rc0, 2.12.0rc1, 2.12.0)
ERROR: No matching distribution found for tensorflow==2.2.0
In [2]:
import os
import pathlib

# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models
Cloning into 'models'...
remote: Enumerating objects: 3733, done.
remote: Counting objects: 100% (3733/3733), done.
remote: Compressing objects: 100% (2872/2872), done.
remote: Total 3733 (delta 1064), reused 1801 (delta 811), pack-reused 0
Receiving objects: 100% (3733/3733), 48.75 MiB | 18.05 MiB/s, done.
Resolving deltas: 100% (1064/1064), done.
Updating files: 100% (3393/3393), done.
In [3]:
%%bash
# Install the Object Detection API.
# NOTE: the %%bash cell magic has to be the very first line of the cell;
# IPython does not recognise a cell magic that follows a comment line.
# Abort on the first failing step: each command below depends on the previous
# one, so continuing after a failure would only hide the real error.
set -e
cd models/research/
# Compile the protobuf definitions into Python modules next to their sources.
protoc object_detection/protos/*.proto --python_out=.
# The TF2 packaging script must sit in the research/ root for pip to find it.
cp object_detection/packages/tf2/setup.py .
python -m pip install .
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing /content/models/research
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting avro-python3
  Downloading avro-python3-1.10.2.tar.gz (38 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting apache-beam
  Downloading apache_beam-2.46.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.5/14.5 MB 80.6 MB/s eta 0:00:00
Requirement already satisfied: pillow in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (8.4.0)
Requirement already satisfied: lxml in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (4.9.2)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (3.7.1)
Requirement already satisfied: Cython in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (0.29.34)
Requirement already satisfied: contextlib2 in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (0.6.0.post1)
Requirement already satisfied: tf-slim in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (1.1.0)
Requirement already satisfied: six in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (1.16.0)
Requirement already satisfied: pycocotools in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (2.0.6)
Collecting lvis
  Downloading lvis-0.5.3-py3-none-any.whl (14 kB)
Requirement already satisfied: scipy in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (1.10.1)
Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (1.5.3)
Collecting tf-models-official>=2.5.1
  Downloading tf_models_official-2.12.0-py2.py3-none-any.whl (2.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.6/2.6 MB 82.1 MB/s eta 0:00:00
Collecting tensorflow_io
  Downloading tensorflow_io-0.32.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (28.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 28.0/28.0 MB 50.2 MB/s eta 0:00:00
Requirement already satisfied: keras in /usr/local/lib/python3.9/dist-packages (from object-detection==0.1) (2.12.0)
Collecting pyparsing==2.4.7
  Downloading pyparsing-2.4.7-py2.py3-none-any.whl (67 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.8/67.8 kB 7.7 MB/s eta 0:00:00
Collecting sacrebleu<=2.2.0
  Downloading sacrebleu-2.2.0-py3-none-any.whl (116 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 116.6/116.6 kB 13.8 MB/s eta 0:00:00
Requirement already satisfied: regex in /usr/local/lib/python3.9/dist-packages (from sacrebleu<=2.2.0->object-detection==0.1) (2022.10.31)
Collecting portalocker
  Downloading portalocker-2.7.0-py2.py3-none-any.whl (15 kB)
Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.9/dist-packages (from sacrebleu<=2.2.0->object-detection==0.1) (0.8.10)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.9/dist-packages (from sacrebleu<=2.2.0->object-detection==0.1) (1.24.3)
Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Requirement already satisfied: tensorflow-hub>=0.6.0 in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (0.13.0)
Collecting pyyaml<6.0,>=5.1
  Downloading PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl (630 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 630.1/630.1 kB 49.1 MB/s eta 0:00:00
Collecting tensorflow-model-optimization>=0.4.1
  Downloading tensorflow_model_optimization-0.7.4-py2.py3-none-any.whl (240 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 240.6/240.6 kB 25.3 MB/s eta 0:00:00
Collecting immutabledict
  Downloading immutabledict-2.2.4-py3-none-any.whl (4.1 kB)
Requirement already satisfied: gin-config in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (0.5.0)
Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (1.5.13)
Requirement already satisfied: py-cpuinfo>=3.3.0 in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (9.0.0)
Requirement already satisfied: tensorflow-datasets in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (4.8.3)
Collecting tensorflow-text~=2.12.0
  Downloading tensorflow_text-2.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.0/6.0 MB 98.5 MB/s eta 0:00:00
Requirement already satisfied: tensorflow~=2.12.0 in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (2.12.0)
Collecting sentencepiece
  Downloading sentencepiece-0.1.98-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 70.2 MB/s eta 0:00:00
Collecting tensorflow-addons
  Downloading tensorflow_addons-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (591 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 591.0/591.0 kB 47.6 MB/s eta 0:00:00
Requirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (5.9.5)
Requirement already satisfied: oauth2client in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (4.1.3)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.6/43.6 kB 4.7 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Requirement already satisfied: opencv-python-headless in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (4.7.0.72)
Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.9/dist-packages (from tf-models-official>=2.5.1->object-detection==0.1) (2.84.0)
Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->object-detection==0.1) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->object-detection==0.1) (2022.7.1)
Requirement already satisfied: absl-py>=0.2.2 in /usr/local/lib/python3.9/dist-packages (from tf-slim->object-detection==0.1) (1.4.0)
Collecting pymongo<4.0.0,>=3.8.0
  Downloading pymongo-3.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (515 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 515.5/515.5 kB 34.8 MB/s eta 0:00:00
Requirement already satisfied: cloudpickle~=2.2.1 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (2.2.1)
Collecting fastavro<2,>=0.23.6
  Downloading fastavro-1.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 64.2 MB/s eta 0:00:00
Collecting zstandard<1,>=0.18.0
  Downloading zstandard-0.21.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 68.0 MB/s eta 0:00:00
Requirement already satisfied: httplib2<0.22.0,>=0.8 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (0.21.0)
Requirement already satisfied: pyarrow<10.0.0,>=3.0.0 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (9.0.0)
Requirement already satisfied: proto-plus<2,>=1.7.1 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (1.22.2)
Collecting hdfs<3.0.0,>=2.1.0
  Downloading hdfs-2.7.0-py3-none-any.whl (34 kB)
Collecting dill<0.3.2,>=0.3.1.1
  Downloading dill-0.3.1.1.tar.gz (151 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 152.0/152.0 kB 14.7 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting objsize<0.7.0,>=0.6.1
  Downloading objsize-0.6.1-py3-none-any.whl (9.3 kB)
Collecting crcmod<2.0,>=1.7
  Downloading crcmod-1.7.tar.gz (89 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 89.7/89.7 kB 8.6 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Requirement already satisfied: requests<3.0.0,>=2.24.0 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (2.27.1)
Requirement already satisfied: protobuf<4,>3.12.2 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (3.20.3)
Collecting fasteners<1.0,>=0.3
  Downloading fasteners-0.18-py3-none-any.whl (18 kB)
Requirement already satisfied: typing-extensions>=3.7.0 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (4.5.0)
Collecting orjson<4.0
  Downloading orjson-3.8.10-cp39-cp39-manylinux_2_28_x86_64.whl (140 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 140.5/140.5 kB 11.4 MB/s eta 0:00:00
Requirement already satisfied: pydot<2,>=1.2.0 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (1.4.2)
Requirement already satisfied: grpcio!=1.48.0,<2,>=1.33.1 in /usr/local/lib/python3.9/dist-packages (from apache-beam->object-detection==0.1) (1.54.0)
Requirement already satisfied: cycler>=0.10.0 in /usr/local/lib/python3.9/dist-packages (from lvis->object-detection==0.1) (0.11.0)
Requirement already satisfied: opencv-python>=4.1.0.25 in /usr/local/lib/python3.9/dist-packages (from lvis->object-detection==0.1) (4.7.0.72)
Requirement already satisfied: kiwisolver>=1.1.0 in /usr/local/lib/python3.9/dist-packages (from lvis->object-detection==0.1) (1.4.4)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->object-detection==0.1) (4.39.3)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->object-detection==0.1) (23.1)
Requirement already satisfied: importlib-resources>=3.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib->object-detection==0.1) (5.12.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib->object-detection==0.1) (1.0.7)
Requirement already satisfied: tensorflow-io-gcs-filesystem==0.32.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow_io->object-detection==0.1) (0.32.0)
Requirement already satisfied: google-auth<3.0.0dev,>=1.19.0 in /usr/local/lib/python3.9/dist-packages (from google-api-python-client>=1.6.7->tf-models-official>=2.5.1->object-detection==0.1) (2.17.3)
Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5 in /usr/local/lib/python3.9/dist-packages (from google-api-python-client>=1.6.7->tf-models-official>=2.5.1->object-detection==0.1) (2.11.0)
Requirement already satisfied: google-auth-httplib2>=0.1.0 in /usr/local/lib/python3.9/dist-packages (from google-api-python-client>=1.6.7->tf-models-official>=2.5.1->object-detection==0.1) (0.1.0)
Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.9/dist-packages (from google-api-python-client>=1.6.7->tf-models-official>=2.5.1->object-detection==0.1) (4.1.1)
Collecting docopt
  Downloading docopt-0.6.2.tar.gz (25 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.9/dist-packages (from importlib-resources>=3.2.0->matplotlib->object-detection==0.1) (3.15.0)
Requirement already satisfied: urllib3 in /usr/local/lib/python3.9/dist-packages (from kaggle>=1.3.9->tf-models-official>=2.5.1->object-detection==0.1) (1.26.15)
Requirement already satisfied: certifi in /usr/local/lib/python3.9/dist-packages (from kaggle>=1.3.9->tf-models-official>=2.5.1->object-detection==0.1) (2022.12.7)
Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from kaggle>=1.3.9->tf-models-official>=2.5.1->object-detection==0.1) (4.65.0)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.9/dist-packages (from kaggle>=1.3.9->tf-models-official>=2.5.1->object-detection==0.1) (8.0.1)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/dist-packages (from requests<3.0.0,>=2.24.0->apache-beam->object-detection==0.1) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/dist-packages (from requests<3.0.0,>=2.24.0->apache-beam->object-detection==0.1) (3.4)
Requirement already satisfied: flatbuffers>=2.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (23.3.3)
Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (3.3.0)
Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (0.2.0)
Requirement already satisfied: tensorflow-estimator<2.13,>=2.12.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (2.12.0)
Requirement already satisfied: wrapt<1.15,>=1.11.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (1.14.1)
Requirement already satisfied: jax>=0.3.15 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (0.3.25)
Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (2.3.0)
Collecting numpy>=1.17
  Downloading numpy-1.23.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.1/17.1 MB 77.6 MB/s eta 0:00:00
Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (1.6.3)
Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (67.7.2)
Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (16.0.0)
Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (0.4.0)
Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (3.8.0)
Requirement already satisfied: tensorboard<2.13,>=2.12 in /usr/local/lib/python3.9/dist-packages (from tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (2.12.2)
Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.9/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-official>=2.5.1->object-detection==0.1) (0.1.8)
Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.9/dist-packages (from oauth2client->tf-models-official>=2.5.1->object-detection==0.1) (0.5.0)
Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.9/dist-packages (from oauth2client->tf-models-official>=2.5.1->object-detection==0.1) (0.3.0)
Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.9/dist-packages (from oauth2client->tf-models-official>=2.5.1->object-detection==0.1) (4.9)
Requirement already satisfied: scikit-learn>=0.21.3 in /usr/local/lib/python3.9/dist-packages (from seqeval->tf-models-official>=2.5.1->object-detection==0.1) (1.2.2)
Collecting typeguard<3.0.0,>=2.7
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Requirement already satisfied: toml in /usr/local/lib/python3.9/dist-packages (from tensorflow-datasets->tf-models-official>=2.5.1->object-detection==0.1) (0.10.2)
Requirement already satisfied: promise in /usr/local/lib/python3.9/dist-packages (from tensorflow-datasets->tf-models-official>=2.5.1->object-detection==0.1) (2.3)
Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.9/dist-packages (from tensorflow-datasets->tf-models-official>=2.5.1->object-detection==0.1) (1.13.1)
Requirement already satisfied: etils[enp,epath]>=0.9.0 in /usr/local/lib/python3.9/dist-packages (from tensorflow-datasets->tf-models-official>=2.5.1->object-detection==0.1) (1.2.0)
Requirement already satisfied: click in /usr/local/lib/python3.9/dist-packages (from tensorflow-datasets->tf-models-official>=2.5.1->object-detection==0.1) (8.1.3)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.9/dist-packages (from astunparse>=1.6.0->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (0.40.0)
Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /usr/local/lib/python3.9/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0dev,>=1.31.5->google-api-python-client>=1.6.7->tf-models-official>=2.5.1->object-detection==0.1) (1.59.0)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from google-auth<3.0.0dev,>=1.19.0->google-api-python-client>=1.6.7->tf-models-official>=2.5.1->object-detection==0.1) (5.3.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=0.21.3->seqeval->tf-models-official>=2.5.1->object-detection==0.1) (3.1.0)
Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from scikit-learn>=0.21.3->seqeval->tf-models-official>=2.5.1->object-detection==0.1) (1.2.0)
Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.9/dist-packages (from tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (0.7.0)
Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (2.2.3)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.9/dist-packages (from tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (3.4.3)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.9/dist-packages (from tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (1.8.1)
Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /usr/local/lib/python3.9/dist-packages (from tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (1.0.0)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.9/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-official>=2.5.1->object-detection==0.1) (1.3)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.9/dist-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (1.3.1)
Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.9/dist-packages (from markdown>=2.6.8->tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (6.6.0)
Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.9/dist-packages (from werkzeug>=1.0.1->tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (2.1.2)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.9/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow~=2.12.0->tf-models-official>=2.5.1->object-detection==0.1) (3.2.2)
Building wheels for collected packages: object-detection, avro-python3, crcmod, dill, seqeval, docopt
  Building wheel for object-detection (setup.py): started
  Building wheel for object-detection (setup.py): finished with status 'done'
  Created wheel for object-detection: filename=object_detection-0.1-py3-none-any.whl size=1697012 sha256=265c01204289cc2925b9386cd706cdec9cfaa1b4288a6473f25b4fc58ca95f35
  Stored in directory: /tmp/pip-ephem-wheel-cache-9pe5u1_n/wheels/1d/d6/23/8af3ea45a1048e13ebe55b2675eecc8eb39c9af79b37e078ae
  Building wheel for avro-python3 (setup.py): started
  Building wheel for avro-python3 (setup.py): finished with status 'done'
  Created wheel for avro-python3: filename=avro_python3-1.10.2-py3-none-any.whl size=44008 sha256=e4b8c81207bd1c163f3e2129d6c696af9c8d595d20684609aaf9877b2b799f2f
  Stored in directory: /root/.cache/pip/wheels/5a/29/4d/510c0e098c49c5e49519f430481a5425e60b8752682d7b1e55
  Building wheel for crcmod (setup.py): started
  Building wheel for crcmod (setup.py): finished with status 'done'
  Created wheel for crcmod: filename=crcmod-1.7-cp39-cp39-linux_x86_64.whl size=36921 sha256=959f420ba80feb2ce7d097c3631d60e3d2b778c1ae49024b57f7c69b212f8326
  Stored in directory: /root/.cache/pip/wheels/4a/6c/a6/ffdd136310039bf226f2707a9a8e6857be7d70a3fc061f6b36
  Building wheel for dill (setup.py): started
  Building wheel for dill (setup.py): finished with status 'done'
  Created wheel for dill: filename=dill-0.3.1.1-py3-none-any.whl size=78545 sha256=a06e7ae918cfbf79f5324ec8a0c05067535804a3ff784e4f74a29a2ac658cc42
  Stored in directory: /root/.cache/pip/wheels/4f/0b/ce/75d96dd714b15e51cb66db631183ea3844e0c4a6d19741a149
  Building wheel for seqeval (setup.py): started
  Building wheel for seqeval (setup.py): finished with status 'done'
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16180 sha256=ce8d4d2964f9c97e24a3c5509a5a6dc3a9ca512d2086d4fdf9c4e73536d16712
  Stored in directory: /root/.cache/pip/wheels/e2/a5/92/2c80d1928733611c2747a9820e1324a6835524d9411510c142
  Building wheel for docopt (setup.py): started
  Building wheel for docopt (setup.py): finished with status 'done'
  Created wheel for docopt: filename=docopt-0.6.2-py2.py3-none-any.whl size=13721 sha256=4ff5b88453baf4fff41f98dc13a637419110d8ebe2579b1e25aa7d2012fe5304
  Stored in directory: /root/.cache/pip/wheels/70/4a/46/1309fc853b8d395e60bafaf1b6df7845bdd82c95fd59dd8d2b
Successfully built object-detection avro-python3 crcmod dill seqeval docopt
Installing collected packages: sentencepiece, docopt, crcmod, zstandard, typeguard, tensorflow_io, pyyaml, pyparsing, pymongo, portalocker, orjson, objsize, numpy, immutabledict, fasteners, fastavro, dill, colorama, avro-python3, tensorflow-model-optimization, tensorflow-addons, sacrebleu, hdfs, apache-beam, seqeval, lvis, tensorflow-text, tf-models-official, object-detection
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 6.0
    Uninstalling PyYAML-6.0:
      Successfully uninstalled PyYAML-6.0
  Attempting uninstall: pyparsing
    Found existing installation: pyparsing 3.0.9
    Uninstalling pyparsing-3.0.9:
      Successfully uninstalled pyparsing-3.0.9
  Attempting uninstall: numpy
    Found existing installation: numpy 1.24.3
    Uninstalling numpy-1.24.3:
      Successfully uninstalled numpy-1.24.3
Successfully installed apache-beam-2.46.0 avro-python3-1.10.2 colorama-0.4.6 crcmod-1.7 dill-0.3.1.1 docopt-0.6.2 fastavro-1.7.3 fasteners-0.18 hdfs-2.7.0 immutabledict-2.2.4 lvis-0.5.3 numpy-1.23.5 object-detection-0.1 objsize-0.6.1 orjson-3.8.10 portalocker-2.7.0 pymongo-3.13.0 pyparsing-2.4.7 pyyaml-5.4.1 sacrebleu-2.2.0 sentencepiece-0.1.98 seqeval-1.2.2 tensorflow-addons-0.20.0 tensorflow-model-optimization-0.7.4 tensorflow-text-2.12.1 tensorflow_io-0.32.0 tf-models-official-2.12.0 typeguard-2.13.3 zstandard-0.21.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
orbax-checkpoint 0.2.1 requires jax>=0.4.8, but you have jax 0.3.25 which is incompatible.
flax 0.6.9 requires jax>=0.4.2, but you have jax 0.3.25 which is incompatible.
chex 0.1.7 requires jax>=0.4.6, but you have jax 0.3.25 which is incompatible.
In [4]:
import matplotlib
import matplotlib.pyplot as plt

import os
import random
import io
import imageio
import glob
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display, Javascript
from IPython.display import Image as IPyImage

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.utils import colab_utils
from object_detection.builders import model_builder

%matplotlib inline
/usr/local/lib/python3.9/dist-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: 

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

  warnings.warn(

Utilities

In [5]:
def load_image_into_numpy_array(path):
  """Load an image from file into a numpy array.

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  The image is always converted to 3-channel RGB, so grayscale, palette,
  RGBA or CMYK files are accepted and still yield a (height, width, 3) array.

  Args:
    path: a file path.

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  # Read through a context manager so the file handle is closed promptly
  # instead of being left for the garbage collector.
  with tf.io.gfile.GFile(path, 'rb') as image_file:
    img_data = image_file.read()
  # Force RGB: without this, any image that does not decode to exactly three
  # channels cannot be laid out as (height, width, 3).
  image = Image.open(BytesIO(img_data)).convert('RGB')
  # np.array on a PIL image yields a writable (height, width, channels) copy
  # directly, avoiding the slow per-pixel getdata() round trip.
  return np.array(image, dtype=np.uint8)

def plot_detections(image_np,
                    boxes,
                    classes,
                    scores,
                    category_index,
                    figsize=(12, 16),
                    image_name=None,
                    min_score_thresh=0.8):
  """Wrapper function to visualize detections.

  The boxes are drawn onto a copy of `image_np`; the input array itself is
  left unmodified.

  Args:
    image_np: uint8 numpy array with shape (img_height, img_width, 3)
    boxes: a numpy array of shape [N, 4]
    classes: a numpy array of shape [N]. Note that class indices are 1-based,
      and match the keys in the label map.
    scores: a numpy array of shape [N] or None.  If scores=None, then
      this function assumes that the boxes to be plotted are groundtruth
      boxes and plot all boxes as black with no classes or scores.
    category_index: a dict containing category dictionaries (each holding
      category index `id` and category name `name`) keyed by category indices.
    figsize: accepted for backward compatibility but currently unused; the
      image is drawn on the current matplotlib axes (or written to file), so
      callers control the figure size themselves.
    image_name: a name for the image file.  If given, the annotated image is
      saved under this name instead of being shown.
    min_score_thresh: minimum score passed to the visualization utility as
      its `min_score_thresh` argument.  Defaults to 0.8.
  """
  image_np_with_annotations = image_np.copy()
  viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_annotations,
      boxes,
      classes,
      scores,
      category_index,
      use_normalized_coordinates=True,
      min_score_thresh=min_score_thresh)
  if image_name:
    plt.imsave(image_name, image_np_with_annotations)
  else:
    plt.imshow(image_np_with_annotations)

Rubber Ducky data

We will start with some toy (literally) data consisting of 5 images of a rubber ducky. Note that the COCO dataset contains a number of animals, but notably, it does not contain rubber duckies (or even ducks, for that matter), so this is a novel class.

In [6]:
# Load images and visualize
train_image_dir = 'models/research/object_detection/test_images/ducky/train/'
# The training images are named robertducky1.jpg ... robertducky5.jpg.
train_images_np = [
    load_image_into_numpy_array(
        os.path.join(train_image_dir, 'robertducky%d.jpg' % image_number))
    for image_number in range(1, 6)
]

# Turn off the grid, ticks and tick labels so that only the images are shown.
plt.rcParams.update({
    'axes.grid': False,
    'xtick.labelsize': False,
    'ytick.labelsize': False,
    'xtick.top': False,
    'xtick.bottom': False,
    'ytick.left': False,
    'ytick.right': False,
    'figure.figsize': [14, 7],
})

# Lay the images out on a 2x3 grid (subplot positions are 1-based).
for subplot_position, train_image_np in enumerate(train_images_np, start=1):
  plt.subplot(2, 3, subplot_position)
  plt.imshow(train_image_np)
plt.show()

Annotate images with bounding boxes

In this cell you will annotate the rubber duckies --- draw a box around the rubber ducky in each image; click `next image` to go to the next image and `submit` when there are no more images.

If you'd like to skip the manual annotation step, we totally understand. In this case, simply skip this cell and run the next cell instead, where we've prepopulated the groundtruth with pre-annotated bounding boxes.

In [11]:
# `gt_boxes` starts out empty and is handed to the annotation widget as its
# `box_storage_pointer`.  Presumably the widget fills this same list with one
# array of boxes per image once the annotations are submitted -- TODO confirm
# against `colab_utils.annotate`.
gt_boxes = []
colab_utils.annotate(train_images_np, box_storage_pointer=gt_boxes)
'--boxes array populated--'

In case you didn't want to label...

Run this cell only if you didn't annotate anything above and would prefer to just use our pre-annotated boxes. Don't forget to uncomment the lines before running it.

In [8]:
# gt_boxes = [
#             np.array([[0.436, 0.591, 0.629, 0.712]], dtype=np.float32),
#             np.array([[0.539, 0.583, 0.73, 0.71]], dtype=np.float32),
#             np.array([[0.464, 0.414, 0.626, 0.548]], dtype=np.float32),
#             np.array([[0.313, 0.308, 0.648, 0.526]], dtype=np.float32),
#             np.array([[0.256, 0.444, 0.484, 0.629]], dtype=np.float32)
# ]

Prepare data for training

Below we add the class annotations (for simplicity, we assume a single class in this colab, though it should be straightforward to extend this to handle multiple classes). We also convert everything to the format that the training loop below expects (e.g., everything converted to tensors, classes converted to one-hot representations, etc.).

In [12]:
# By convention, our non-background classes start counting at 1.  Given
# that we will be predicting just one class, we will therefore assign it a
# `class id` of 1.
duck_class_id = 1
num_classes = 1

category_index = {duck_class_id: {'id': duck_class_id, 'name': 'rubber_ducky'}}

# Fail early with a clear message if the annotations are missing or
# incomplete.  Without this check `zip` below would silently drop images and
# the problem would only surface later, inside the training loop.
if len(gt_boxes) != len(train_images_np):
  raise ValueError(
      'Expected one groundtruth box array per training image, but got '
      '%d box arrays for %d images. Annotate every image (or uncomment the '
      'pre-annotated boxes) before running this cell.'
      % (len(gt_boxes), len(train_images_np)))

# Convert class labels to one-hot; convert everything to tensors.
# The `label_id_offset` here shifts all classes by a certain number of indices;
# we do this here so that the model receives one-hot labels where non-background
# classes start counting at the zeroth index.  This is ordinarily just handled
# automatically in our training binaries, but we need to reproduce it here.
label_id_offset = 1
train_image_tensors = []
gt_classes_one_hot_tensors = []
gt_box_tensors = []
for (train_image_np, gt_box_np) in zip(
    train_images_np, gt_boxes):
  # Each image gets a leading batch dimension of size 1.
  train_image_tensors.append(tf.expand_dims(tf.convert_to_tensor(
      train_image_np, dtype=tf.float32), axis=0))
  gt_box_tensors.append(tf.convert_to_tensor(gt_box_np, dtype=tf.float32))
  # Every box belongs to the single (1-based) duck class; subtracting the
  # offset yields the zero-based index that is then one-hot encoded.
  zero_indexed_groundtruth_classes = tf.convert_to_tensor(
      np.ones(shape=[gt_box_np.shape[0]], dtype=np.int32) - label_id_offset)
  gt_classes_one_hot_tensors.append(tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes))
print('Done prepping data.')
Done prepping data.

Let's just visualize the rubber duckies as a sanity check

In [13]:
plt.figure(figsize=(30, 15))
for idx in range(5):
  num_boxes = gt_boxes[idx].shape[0]
  # Give every groundtruth box a score of 100%.  The score array is sized per
  # image so that images annotated with more than one box are handled too.
  dummy_scores = np.ones(shape=[num_boxes], dtype=np.float32)
  plt.subplot(2, 3, idx+1)
  plot_detections(
      train_images_np[idx],
      gt_boxes[idx],
      np.ones(shape=[num_boxes], dtype=np.int32),
      dummy_scores, category_index)
plt.show()

Create model and restore weights for all but last layer

In this cell we build a single stage detection architecture (RetinaNet) and restore all but the classification layer at the top (which will be automatically randomly initialized).

For simplicity, we have hardcoded a number of things in this colab for the specific RetinaNet architecture at hand (including assuming that the image size will always be 640x640); however, it is not difficult to generalize to other model configurations.

In [14]:
# Download the checkpoint and put it into models/research/object_detection/test_data/

!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
!mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint models/research/object_detection/test_data/
--2023-04-27 06:11:41--  http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 74.125.201.128, 2607:f8b0:4001:c01::80
Connecting to download.tensorflow.org (download.tensorflow.org)|74.125.201.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 244817203 (233M) [application/x-tar]
Saving to: ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’

ssd_resnet50_v1_fpn 100%[===================>] 233.48M   124MB/s    in 1.9s    

2023-04-27 06:11:43 (124 MB/s) - ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’ saved [244817203/244817203]

In [15]:
# Reset the global Keras state before building the model.
tf.keras.backend.clear_session()

print('Building model and restoring weights for fine-tuning...', flush=True)
# Same value as in the data preparation cell; repeated here so that this cell
# does not depend on it.
num_classes = 1
pipeline_config = 'models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
checkpoint_path = 'models/research/object_detection/test_data/checkpoint/ckpt-0'

# Load pipeline config and build a detection model.
#
# Since we are working off of a COCO architecture which predicts 90
# class slots by default, we override the `num_classes` field here to be just
# one (for our new rubber ducky class).
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
model_config.ssd.num_classes = num_classes
# NOTE(review): presumably keeps the batch-norm layers fixed while
# fine-tuning -- confirm against the SSD config documentation.
model_config.ssd.freeze_batchnorm = True
detection_model = model_builder.build(
      model_config=model_config, is_training=True)

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
# Mirror the attribute names of the detection model, but expose only the
# feature extractor and the trimmed-down box predictor defined above.
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
# Only part of the checkpoint is mapped onto the model (the classification
# head is deliberately left out), so the restore is marked as partial.
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
# (640x640 is the fixed input size assumed throughout this colab).
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
Building model and restoring weights for fine-tuning...
Weights restored!

Eager mode custom training loop

In [16]:
tf.keras.backend.set_learning_phase(True)

# Tunable hyperparameters.  With only 5 training images a much larger batch
# size would not make sense, even though more examples would fit in memory.
batch_size = 4
learning_rate = 0.01
num_batches = 100

# Fine-tune only the variables of the two prediction heads (box regression
# and classification); they are picked out by their variable name prefix.
trainable_variables = detection_model.trainable_variables
prefixes_to_train = [
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',
  'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead']
to_fine_tune = [
    variable for variable in trainable_variables
    if variable.name.startswith(tuple(prefixes_to_train))]

# Set up forward + backward pass for a single train step.
def get_model_train_step_function(model, optimizer, vars_to_fine_tune):
  """Get a tf.function for training step.

  Args:
    model: the detection model to train.  It is used for preprocessing,
      prediction and loss computation.
    optimizer: the optimizer whose `apply_gradients` is called on every step.
    vars_to_fine_tune: the list of variables that gradients are computed for
      and applied to; all other variables are left untouched.

  Returns:
    A function taking (image_tensors, groundtruth_boxes_list,
    groundtruth_classes_list) that runs one training step and returns the
    total loss.
  """

  # Use tf.function for a bit of speed.
  # Comment out the tf.function decorator if you want the inside of the
  # function to run eagerly.
  @tf.function
  def train_step_fn(image_tensors,
                    groundtruth_boxes_list,
                    groundtruth_classes_list):
    """A single training iteration.

    Args:
      image_tensors: A list of [1, height, width, 3] Tensor of type tf.float32.
        Note that the height and width can vary across images, as they are
        reshaped within this function to be 640x640.
      groundtruth_boxes_list: A list of Tensors of shape [N_i, 4] with type
        tf.float32 representing groundtruth boxes for each image in the batch.
      groundtruth_classes_list: A list of Tensors of shape [N_i, num_classes]
        with type tf.float32 representing groundtruth classes for each image
        in the batch.

    Returns:
      A scalar tensor representing the total loss for the input batch.
    """
    # One shape row per image actually present in this batch.  Deriving the
    # count from the input (rather than from a global batch size) keeps the
    # shapes consistent when fewer images than the nominal batch size are
    # passed in.
    shapes = tf.constant(
        len(image_tensors) * [[640, 640, 3]], dtype=tf.int32)
    model.provide_groundtruth(
        groundtruth_boxes_list=groundtruth_boxes_list,
        groundtruth_classes_list=groundtruth_classes_list)
    with tf.GradientTape() as tape:
      # Preprocess with the model that was passed in, not a global one.
      preprocessed_images = tf.concat(
          [model.preprocess(image_tensor)[0]
           for image_tensor in image_tensors], axis=0)
      prediction_dict = model.predict(preprocessed_images, shapes)
      losses_dict = model.loss(prediction_dict, shapes)
      total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']
    # Only the forward pass needs to be recorded; the gradient computation and
    # the optimizer update happen after the tape context has been left.
    gradients = tape.gradient(total_loss, vars_to_fine_tune)
    optimizer.apply_gradients(zip(gradients, vars_to_fine_tune))
    return total_loss

  return train_step_fn

optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
train_step_fn = get_model_train_step_function(
    detection_model, optimizer, to_fine_tune)

print('Start fine-tuning!', flush=True)
for idx in range(num_batches):
  # Pick a random subset of example indices for this batch.
  all_keys = list(range(len(train_images_np)))
  random.shuffle(all_keys)
  example_keys = all_keys[:batch_size]

  # Note that we do not do data augmentation in this demo.  If you want
  # a fun exercise, we recommend experimenting with random horizontal flipping
  # and random cropping :)
  image_tensors = []
  gt_boxes_list = []
  gt_classes_list = []
  for key in example_keys:
    image_tensors.append(train_image_tensors[key])
    gt_boxes_list.append(gt_box_tensors[key])
    gt_classes_list.append(gt_classes_one_hot_tensors[key])

  # Training step (forward pass + backwards pass)
  total_loss = train_step_fn(image_tensors, gt_boxes_list, gt_classes_list)

  # Report the loss on every tenth batch.  The loss is converted with an
  # explicit str() so that the printed text keeps its short float32 form.
  if idx % 10 == 0:
    print('batch {} of {}, loss={}'.format(
        idx, num_batches, str(total_loss.numpy())), flush=True)

print('Done fine-tuning!')
Start fine-tuning!
/usr/local/lib/python3.9/dist-packages/keras/backend.py:452: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn(
batch 0 of 100, loss=1.3279171
batch 10 of 100, loss=0.15262945
batch 20 of 100, loss=0.036502436
batch 30 of 100, loss=0.019084902
batch 40 of 100, loss=0.009745872
batch 50 of 100, loss=0.0057426076
batch 60 of 100, loss=0.0046342667
batch 70 of 100, loss=0.0035587228
batch 80 of 100, loss=0.0024153707
batch 90 of 100, loss=0.002991657
Done fine-tuning!

Load test images and run inference with new model!

In [17]:
test_image_dir = 'models/research/object_detection/test_images/ducky/test/'
# The test frames are named out1.jpg ... out49.jpg.  Each image gets a leading
# axis of size 1 so that it can later be fed to the model as a batch of one.
test_images_np = [
    load_image_into_numpy_array(
        os.path.join(test_image_dir, 'out%d.jpg' % frame_number)
    )[np.newaxis, ...]
    for frame_number in range(1, 50)
]

# Again, comment out this decorator if you want to run inference eagerly
@tf.function
def detect(input_tensor):
  """Run the fine-tuned detection model on a single input image.

  Args:
    input_tensor: A [1, height, width, 3] Tensor of type tf.float32.
      Height and width can be anything, since the image is passed through
      the model's own preprocessing step first.

  Returns:
    The dict produced by the model's `postprocess` step.  The inference loop
    in this colab reads `detection_boxes`, `detection_classes` and
    `detection_scores` from it.
  """
  model_input, true_shapes = detection_model.preprocess(input_tensor)
  raw_predictions = detection_model.predict(model_input, true_shapes)
  detections = detection_model.postprocess(raw_predictions, true_shapes)
  return detections

# Note that the first frame will trigger tracing of the tf.function, which will
# take some time, after which inference should be fast.

label_id_offset = 1
for i, test_image_np in enumerate(test_images_np):
  input_tensor = tf.convert_to_tensor(test_image_np, dtype=tf.float32)
  detections = detect(input_tensor)

  # Shift the predicted class indices by `label_id_offset` so that they match
  # the 1-based keys of `category_index`.
  detected_classes = (
      detections['detection_classes'][0].numpy().astype(np.uint32)
      + label_id_offset)

  # Passing `image_name` makes plot_detections save each annotated frame to
  # disk; the zero-padded index keeps the files in order when sorted by name.
  plot_detections(
      test_image_np[0],
      detections['detection_boxes'][0].numpy(),
      detected_classes,
      detections['detection_scores'][0].numpy(),
      category_index,
      figsize=(15, 20),
      image_name='gif_frame_%02d.jpg' % i)
In [18]:
# Make sure the FreeImage plugin library needed for the 'GIF-FI' format below
# is available.
imageio.plugins.freeimage.download()

anim_file = 'duckies_test.gif'

# The frame files carry a zero-padded index, so sorting by name puts them in
# playback order.
filenames = sorted(glob.glob('gif_frame_*.jpg'))
images = [imageio.imread(filename) for filename in filenames]

imageio.mimsave(anim_file, images, 'GIF-FI', fps=5)

# Read the finished GIF through a context manager so that the file handle is
# closed again once its bytes have been handed to the display widget.
with open(anim_file, 'rb') as gif_file:
  display(IPyImage(gif_file.read()))
Imageio: 'libfreeimage-3.16.0-linux64.so' was not found on your computer; downloading it now.
Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/freeimage/libfreeimage-3.16.0-linux64.so (4.6 MB)
Downloading: 8192/4830080 bytes (0.2%)4830080/4830080 bytes (100.0%)
  Done
File saved as /root/.imageio/freeimage/libfreeimage-3.16.0-linux64.so.
<ipython-input-18-a227bebdd11f>:10: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.
  image = imageio.imread(filename)
In [ ]: